{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imputation Experiment on MUJOCO Dataset\n",
    "\n",
    "Forked from [SSSD repo](https://github.com/AI4HealthUOL/SSSD) : \n",
    "\n",
    "We collected the dataset directly from [NRTSI repository](https://github.com/lupalab/NRTSI/tree/main/codes_regularly-sampled), which provides a [link](https://www.dropbox.com/s/pjccc2piis8g2fx/mujoco_train.npy?dl=0) for the train set, and another [link](https://www.dropbox.com/s/ktkswh77sueqfy8/mujoco_test.npy?dl=0) for the test set.  \n",
    "\n",
    "Shan, Siyuan, Yang Li, and Junier B. Oliva. \"Nrtsi: Non-recurrent time series imputation.\" arXiv preprint arXiv:2102.03340 (2021)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import sys\n",
    "sys.path.append(os.path.join(os.path.dirname('__file__'), '../'))\n",
    "\n",
    "from engine.solver import Trainer\n",
    "from torch.utils.data import Dataset\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from Utils.io_utils import load_yaml_config, instantiate_from_config\n",
    "from Models.interpretable_diffusion.model_utils import normalize_to_neg_one_to_one, unnormalize_to_zero_to_one"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_mask(observed_values, missing_ratio=0.1, seed=1984, exclude_features=None):\n",
    "    observed_masks = ~np.isnan(observed_values)\n",
    "    if exclude_features is not None:\n",
    "        observed_masks[:, exclude_features] = False\n",
    "\n",
    "    # randomly set some percentage as ground-truth\n",
    "    masks = observed_masks.reshape(-1).copy()\n",
    "    obs_indices = np.where(masks)[0].tolist()\n",
    "\n",
    "    # Store the state of the RNG to restore later.\n",
    "    st0 = np.random.get_state()\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    miss_indices = np.random.choice(\n",
    "        obs_indices, (int)(len(obs_indices) * missing_ratio), replace=False\n",
    "    )\n",
    "\n",
    "    # Restore RNG.\n",
    "    np.random.set_state(st0)\n",
    "    \n",
    "    masks[miss_indices] = False\n",
    "    gt_masks = masks.reshape(observed_masks.shape)\n",
    "    observed_values = np.nan_to_num(observed_values)\n",
    "    return observed_values, observed_masks, gt_masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MUJOCODataset(Dataset):\n",
    "    def __init__(self, data, regular=True, ratio=0.):\n",
    "        super(MUJOCODataset, self).__init__()\n",
    "        self.sample_num = data.shape[0]\n",
    "        self.samples = data\n",
    "        self.regular = regular\n",
    "        self.mask = np.empty([0, data.shape[1], data.shape[2]])\n",
    "        if not self.regular:\n",
    "            for i in range(data.shape[0]):\n",
    "                *_, mask = random_mask(data[i, :, :], ratio)\n",
    "                self.mask = np.row_stack([self.mask, np.expand_dims(mask, 0)])\n",
    "        self.mask = self.mask.astype(bool)\n",
    "\n",
    "    def __getitem__(self, ind):\n",
    "        x = self.samples[ind, :, :]\n",
    "        if self.regular:\n",
    "            return torch.from_numpy(x).float()\n",
    "        mask = self.mask[ind, :, :]\n",
    "        return torch.from_numpy(x).float(), torch.from_numpy(mask)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.sample_num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8000, 100, 14) (2000, 100, 14)\n"
     ]
    }
   ],
   "source": [
    "train = np.load('../Data/mujoco_train.npy')  # downloaded from https://www.dropbox.com/s/pjccc2piis8g2fx/mujoco_train.npy?dl=0\n",
    "test = np.load('../Data/mujoco_test.npy')  # downloaded from https://www.dropbox.com/s/ktkswh77sueqfy8/mujoco_test.npy?dl=0\n",
    "\n",
    "print(train.shape, test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = MinMaxScaler()\n",
    "train_scaled = normalize_to_neg_one_to_one(scaler.fit_transform(train.reshape(-1, train.shape[-1]))).reshape(train.shape)\n",
    "test_scaled = scaler.transform(test.reshape(-1, test.shape[-1])).reshape(test.shape)\n",
    "test_scaled = normalize_to_neg_one_to_one(test_scaled)\n",
    "\n",
    "train_dataset = MUJOCODataset(train_scaled)\n",
    "dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0, pin_memory=True, sampler=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args_Example:\n",
    "    def __init__(self) -> None:\n",
    "        self.config_path = '../Config/mujoco_sssd.yaml'\n",
    "        self.save_dir = '../imputation_exp'\n",
    "        self.gpu = 0\n",
    "        os.makedirs(self.save_dir, exist_ok=True)\n",
    "\n",
    "args =  Args_Example()\n",
    "configs = load_yaml_config(args.config_path)\n",
    "device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "model = instantiate_from_config(configs['model']).to(device)\n",
    "trainer = Trainer(config=configs, args=args, model=model, dataloader={'dataloader':dataloader})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0150f84ae83a45a78fc6da2a7b633782",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/12000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training complete\n"
     ]
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "614c47ff9d7f45038082445a519fea22",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f6d3aec9f5a45de8715e3b8e2c88dbb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cf94050d7c4c4d11915b700b80798090",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "71056bafb07e4079ab73e71165ea4e3c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4abb9c73f46949bfba05b439d9f29bfb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ead26853a09b4caaa7442d86849004c2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb34ca773efe420fabe6e120ea0f1ed7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e34117ad9adf41a28691e4db5136c756",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Now with 0.7 unobserved: 0.00027023070656259114\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2886f42d2d0b47238ceb12539b0174f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "103db6404fb94cb7addf5bfa97a16e50",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dc742bdfe56d47d7a49eceb2d244e34a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c53ba5668394840a4f80e07a09e0f3b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "69ff9cfb37a14d24aa06e0b10d9e6acf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1a9c90b5a067400f861faa9f3eaba541",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fe8cfac6b167497b87e3c70db0ed10ae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "18e9544711b7482498dc442d0c8ff809",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Now with 0.8 unobserved: 0.000321200980367213\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "64b688b67689468f8d910085514b51be",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "906a3eeabfc44f259f6b72a7c72bb6c3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2eadc930eee9449aa5749587d8cf1422",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "75bb526f88274d0c8f6ed9b98abc1094",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4ab8f9d143734d9d9113bf337235d15a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "389add770af04812a74db7b41f564dc4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cf26cc7ec95d43c28153cc28005cc820",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "654cd333d76e46e6a7dc914d116d577a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "conditional sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Now with 0.9 unobserved: 0.0005406215348419648\n"
     ]
    }
   ],
   "source": [
    "sample_num, seq_length, feat_num = test_scaled.shape\n",
    "\n",
    "for missing_ratio in [0.7, 0.8, 0.9]:\n",
    "    mses = []\n",
    "    samples = np.empty([0, sample_num, seq_length, feat_num])\n",
    "    test_dataset = MUJOCODataset(test_scaled, regular=False, ratio=missing_ratio)\n",
    "    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=0, pin_memory=True, sampler=None)\n",
    "\n",
    "    sample, *_ = trainer.restore(test_dataloader, shape=[seq_length, feat_num], coef=1e-2, stepsize=5e-2, sampling_steps=200)\n",
    "    sample = scaler.inverse_transform(unnormalize_to_zero_to_one(sample.reshape(-1, feat_num))).reshape(sample.shape)\n",
    "    samples = np.row_stack([samples, np.expand_dims(sample, 0)])\n",
    "    mask = test_dataset.mask\n",
    "    mse = mean_squared_error(sample[~mask], test[~mask])\n",
    "    print(f'Now with {missing_ratio} unobserved: {mse}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DPlearning_3.8",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
